{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EgiF12Hf1Dhs"
   },
   "source": [
    "This notebook provides examples to go along with the [textbook](http://manipulation.csail.mit.edu/rl.html).  I recommend having both windows open, side-by-side!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "import torch\n",
    "from pydrake.all import StartMeshcat\n",
    "\n",
    "from manipulation.utils import LoadDataResource, running_as_notebook\n",
    "import manipulation.envs.box_flipup  # no-member\n",
    "from manipulation.utils import RenderDiagram\n",
    "from manipulation.meshcat_utils import plot_surface\n",
    "\n",
    "from psutil import cpu_count\n",
    "\n",
    "# Use about half of the CPUs for the parallel environments.  cpu_count() may\n",
    "# return None, and halving a single CPU would give zero environments, so\n",
    "# fall back to 2 CPUs and always keep at least one environment.\n",
    "num_cpu = max(1, (cpu_count() or 2) // 2) if running_as_notebook else 2\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n",
    "from stable_baselines3.common.env_util import make_vec_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared Meshcat instance: later cells pass it to the environment, record an\n",
    "# animation with it, and plot the critic surface into it.\n",
    "meshcat = StartMeshcat()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RL for box flip-up\n",
    "\n",
    "## State-feedback policy via PPO (with stiffness control)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "observations = \"state\"\n",
    "time_limit = 10 if running_as_notebook else 0.5\n",
    "\n",
    "# Note: Models saved in stable baselines are version specific.  This one\n",
    "# requires python3.8 (and cloudpickle==1.6.0).\n",
    "# The name `model_zip` (rather than `zip`) keeps the builtin zip() usable in\n",
    "# the rest of the notebook.\n",
    "model_zip = f\"box_flipup_ppo_{observations}.zip\"\n",
    "\n",
    "\n",
    "# Use a callback so that the forked process imports the environment.\n",
    "def make_boxflipup():\n",
    "    \"\"\"Create one BoxFlipUp-v0 environment.\n",
    "\n",
    "    Uses the notebook-level `observations` and `time_limit` settings.\n",
    "    \"\"\"\n",
    "    import manipulation.envs.box_flipup\n",
    "\n",
    "    return gym.make(\n",
    "        \"BoxFlipUp-v0\", observations=observations, time_limit=time_limit\n",
    "    )\n",
    "\n",
    "\n",
    "# Subprocess workers are only used when running as a notebook; otherwise the\n",
    "# environments are stepped through DummyVecEnv.\n",
    "env = make_vec_env(\n",
    "    make_boxflipup,\n",
    "    n_envs=num_cpu,\n",
    "    seed=0,\n",
    "    vec_env_cls=SubprocVecEnv if running_as_notebook else DummyVecEnv,\n",
    ")\n",
    "\n",
    "use_pretrained_model = True\n",
    "if use_pretrained_model:\n",
    "    # TODO(russt): Save a trained model that works on Deepnote.\n",
    "    model = PPO.load(LoadDataResource(model_zip), env)\n",
    "elif running_as_notebook:\n",
    "    # This is a relatively small amount of training.  See rl/train_boxflipup.py\n",
    "    # for a version that runs the heavyweight version with multiprocessing.\n",
    "    model = PPO(\"MlpPolicy\", env, verbose=1)\n",
    "    model.learn(total_timesteps=100000)\n",
    "else:\n",
    "    # For testing this notebook, we simply want to make sure that the code runs.\n",
    "    model = PPO(\"MlpPolicy\", env, n_steps=4, n_epochs=2, batch_size=8)\n",
    "    model.learn(total_timesteps=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make a version of the env with meshcat.\n",
    "env = gym.make(\"BoxFlipUp-v0\", meshcat=meshcat, observations=observations)\n",
    "\n",
    "# Roll out the policy deterministically while meshcat records the animation.\n",
    "obs, _ = env.reset()\n",
    "meshcat.StartRecording()\n",
    "for i in range(500 if running_as_notebook else 5):\n",
    "    action, _state = model.predict(obs, deterministic=True)\n",
    "    obs, reward, terminated, truncated, info = env.step(action)\n",
    "    env.render()\n",
    "    # An episode ends by termination or by truncation (e.g. a time limit);\n",
    "    # either way the environment must be reset before stepping again.\n",
    "    if terminated or truncated:\n",
    "        obs, _ = env.reset()\n",
    "meshcat.PublishRecording()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs, _ = env.reset()\n",
    "q_values = np.arange(0, np.pi, 0.05)\n",
    "qdot_values = np.arange(-2, 2, 0.05)\n",
    "Q, Qdot = np.meshgrid(q_values, qdot_values)\n",
    "\n",
    "# TODO(russt): tensorize this...\n",
    "# Evaluate the critic at every grid point by overwriting entries 2 and 7 of\n",
    "# the observation; the other entries keep their values from the reset.\n",
    "V = np.zeros_like(Q)\n",
    "with torch.no_grad():\n",
    "    for i, j in np.ndindex(*Q.shape):\n",
    "        obs[2] = Q[i, j]\n",
    "        obs[7] = Qdot[i, j]\n",
    "        obs_tensor = model.policy.obs_to_tensor(obs)[0]\n",
    "        value = model.policy.predict_values(obs_tensor)[0]\n",
    "        V[i, j] = value.cpu().numpy()[0]\n",
    "\n",
    "# Shift the values so the minimum is 0, then scale so the maximum is 1.\n",
    "V = V - V.min()\n",
    "V = V / V.max()\n",
    "\n",
    "# Clear the previous visualization before drawing the surface.\n",
    "meshcat.Delete()\n",
    "meshcat.ResetRenderMode()\n",
    "plot_surface(meshcat, \"Critic\", Q, Qdot, V, wireframe=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make an environment without meshcat, just to render its system diagram.\n",
    "env = gym.make(\"BoxFlipUp-v0\")\n",
    "\n",
    "# gym.make() returns the environment inside wrappers; go through `unwrapped`\n",
    "# to reach attributes of the underlying environment such as `simulator`.\n",
    "RenderDiagram(env.unwrapped.simulator.get_system(), max_depth=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Robotic Manipulation - Geometric Pose Estimation.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3.10.8 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "metadata": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
